import asyncio
import logging
import time
from typing import TYPE_CHECKING

from cdp_use.cdp.accessibility.commands import GetFullAXTreeReturns
from cdp_use.cdp.accessibility.types import AXNode
from cdp_use.cdp.dom.types import Node
from cdp_use.cdp.target import TargetID

from browser_use.dom.enhanced_snapshot import (
	REQUIRED_COMPUTED_STYLES,
	build_snapshot_lookup,
)
from browser_use.dom.serializer.serializer import DOMTreeSerializer
from browser_use.dom.views import (
	DOMRect,
	EnhancedAXNode,
	EnhancedAXProperty,
	EnhancedDOMTreeNode,
	NodeType,
	SerializedDOMState,
	TargetAllTrees,
)
from browser_use.observability import observe_debug
from browser_use.utils import create_task_with_error_handling

if TYPE_CHECKING:
	from browser_use.browser.session import BrowserSession

# Note: iframe limits are now configurable via BrowserProfile.max_iframes and BrowserProfile.max_iframe_depth


class DomService:
	"""
	Service for getting the DOM tree and other DOM-related information.

	A `BrowserSession` must be provided; every CDP call made by this service goes through it.

	TODO: currently we start a new websocket connection PER STEP, we should definitely keep this persistent
	"""

	logger: logging.Logger

	def __init__(
		self,
		browser_session: 'BrowserSession',
		logger: logging.Logger | None = None,
		cross_origin_iframes: bool = False,
		paint_order_filtering: bool = True,
		max_iframes: int = 100,
		max_iframe_depth: int = 5,
	):
		self.browser_session = browser_session
		self.logger = logger or browser_session.logger
		self.cross_origin_iframes = cross_origin_iframes
		self.paint_order_filtering = paint_order_filtering
		self.max_iframes = max_iframes
		self.max_iframe_depth = max_iframe_depth

	async def __aenter__(self):
		return self

	async def __aexit__(self, exc_type, exc_value, traceback):
		pass  # no need to cleanup anything, browser_session auto handles cleaning up session cache

	def _build_enhanced_ax_node(self, ax_node: AXNode) -> EnhancedAXNode:
		"""Convert a raw CDP accessibility node into an `EnhancedAXNode`.

		Args:
			ax_node: Accessibility node dict as returned by CDP.

		Returns:
			The enhanced node. `properties` is None when the raw node has no (or an empty) `properties`
			list, and `child_ids` is None when the raw node has no (or an empty) `childIds` list.
		"""

		def _value_of(field: str):
			# role / name / description are optional dicts carrying the payload under 'value'
			return ax_node.get(field, {}).get('value', None)

		converted: list[EnhancedAXProperty] | None = None
		raw_properties = ax_node.get('properties')
		if raw_properties:
			converted = []
			for raw_property in raw_properties:
				try:
					ax_property = EnhancedAXProperty(
						name=raw_property['name'],
						value=raw_property.get('value', {}).get('value', None),
						# related_nodes=[],  # TODO: add related nodes
					)
				except ValueError:
					# The property name did not fit the enum (Chrome sometimes returns unexpected ones): skip it
					continue
				converted.append(ax_property)

		return EnhancedAXNode(
			ax_node_id=ax_node['nodeId'],
			ignored=ax_node['ignored'],
			role=_value_of('role'),
			name=_value_of('name'),
			description=_value_of('description'),
			properties=converted,
			child_ids=ax_node.get('childIds') or None,
		)

	async def _get_viewport_ratio(self, target_id: TargetID) -> float:
		"""Get viewport dimensions, device pixel ratio, and scroll position using CDP."""
		cdp_session = await self.browser_session.get_or_create_cdp_session(target_id=target_id, focus=False)

		try:
			# Get the layout metrics which includes the visual viewport
			metrics = await cdp_session.cdp_client.send.Page.getLayoutMetrics(session_id=cdp_session.session_id)

			visual_viewport = metrics.get('visualViewport', {})

			# IMPORTANT: Use CSS viewport instead of device pixel viewport
			# This fixes the coordinate mismatch on high-DPI displays
			css_visual_viewport = metrics.get('cssVisualViewport', {})
			css_layout_viewport = metrics.get('cssLayoutViewport', {})

			# Use CSS pixels (what JavaScript sees) instead of device pixels
			width = css_visual_viewport.get('clientWidth', css_layout_viewport.get('clientWidth', 1920.0))

			# Calculate device pixel ratio
			device_width = visual_viewport.get('clientWidth', width)
			css_width = css_visual_viewport.get('clientWidth', width)
			device_pixel_ratio = device_width / css_width if css_width > 0 else 1.0

			return float(device_pixel_ratio)
		except Exception as e:
			self.logger.debug(f'Viewport size detection failed: {e}')
			# Fallback to default viewport size
			return 1.0

	@classmethod
	def is_element_visible_according_to_all_parents(
		cls, node: EnhancedDOMTreeNode, html_frames: list[EnhancedDOMTreeNode]
	) -> bool:
		"""Check if the element is visible according to all its parent HTML frames."""

		if not node.snapshot_node:
			return False

		computed_styles = node.snapshot_node.computed_styles or {}

		display = computed_styles.get('display', '').lower()
		visibility = computed_styles.get('visibility', '').lower()
		opacity = computed_styles.get('opacity', '1')

		if display == 'none' or visibility == 'hidden':
			return False

		try:
			if float(opacity) <= 0:
				return False
		except (ValueError, TypeError):
			pass

		# Start with the element's local bounds (in its own frame's coordinate system)
		current_bounds = node.snapshot_node.bounds

		if not current_bounds:
			return False  # If there are no bounds, the element is not visible

		"""
		Reverse iterate through the html frames (that can be either iframe or document -> if it's a document frame compare if the current bounds interest with it (taking scroll into account) otherwise move the current bounds by the iframe offset)
		"""
		for frame in reversed(html_frames):
			if (
				frame.node_type == NodeType.ELEMENT_NODE
				and (frame.node_name.upper() == 'IFRAME' or frame.node_name.upper() == 'FRAME')
				and frame.snapshot_node
				and frame.snapshot_node.bounds
			):
				iframe_bounds = frame.snapshot_node.bounds

				# negate the values added in `_construct_enhanced_node`
				current_bounds.x += iframe_bounds.x
				current_bounds.y += iframe_bounds.y

			if (
				frame.node_type == NodeType.ELEMENT_NODE
				and frame.node_name == 'HTML'
				and frame.snapshot_node
				and frame.snapshot_node.scrollRects
				and frame.snapshot_node.clientRects
			):
				# For iframe content, we need to check visibility within the iframe's viewport
				# The scrollRects represent the current scroll position
				# The clientRects represent the viewport size
				# Elements are visible if they fall within the viewport after accounting for scroll

				# The viewport of the frame (what's actually visible)
				viewport_left = 0  # Viewport always starts at 0 in frame coordinates
				viewport_top = 0
				viewport_right = frame.snapshot_node.clientRects.width
				viewport_bottom = frame.snapshot_node.clientRects.height

				# Adjust element bounds by the scroll offset to get position relative to viewport
				# When scrolled down, scrollRects.y is positive, so we subtract it from element's y
				adjusted_x = current_bounds.x - frame.snapshot_node.scrollRects.x
				adjusted_y = current_bounds.y - frame.snapshot_node.scrollRects.y

				frame_intersects = (
					adjusted_x < viewport_right
					and adjusted_x + current_bounds.width > viewport_left
					and adjusted_y < viewport_bottom + 1000
					and adjusted_y + current_bounds.height > viewport_top - 1000
				)

				if not frame_intersects:
					return False

				# Keep the original coordinate adjustment to maintain consistency
				# This adjustment is needed for proper coordinate transformation
				current_bounds.x -= frame.snapshot_node.scrollRects.x
				current_bounds.y -= frame.snapshot_node.scrollRects.y

		# If we reach here, element is visible in main viewport and all containing iframes
		return True

	async def _get_ax_tree_for_all_frames(self, target_id: TargetID) -> GetFullAXTreeReturns:
		"""Recursively collect all frames and merge their accessibility trees into a single array."""

		cdp_session = await self.browser_session.get_or_create_cdp_session(target_id=target_id, focus=False)
		frame_tree = await cdp_session.cdp_client.send.Page.getFrameTree(session_id=cdp_session.session_id)

		def collect_all_frame_ids(frame_tree_node) -> list[str]:
			"""Recursively collect all frame IDs from the frame tree."""
			frame_ids = [frame_tree_node['frame']['id']]

			if 'childFrames' in frame_tree_node and frame_tree_node['childFrames']:
				for child_frame in frame_tree_node['childFrames']:
					frame_ids.extend(collect_all_frame_ids(child_frame))

			return frame_ids

		# Collect all frame IDs recursively
		all_frame_ids = collect_all_frame_ids(frame_tree['frameTree'])

		# Get accessibility tree for each frame
		ax_tree_requests = []
		for frame_id in all_frame_ids:
			ax_tree_request = cdp_session.cdp_client.send.Accessibility.getFullAXTree(
				params={'frameId': frame_id}, session_id=cdp_session.session_id
			)
			ax_tree_requests.append(ax_tree_request)

		# Wait for all requests to complete
		ax_trees = await asyncio.gather(*ax_tree_requests)

		# Merge all AX nodes into a single array
		merged_nodes: list[AXNode] = []
		for ax_tree in ax_trees:
			merged_nodes.extend(ax_tree['nodes'])

		return {'nodes': merged_nodes}

	async def _get_all_trees(self, target_id: TargetID) -> TargetAllTrees:
		"""Fetch the DOM snapshot, DOM tree, AX tree and device pixel ratio of a target.

		The four requests run concurrently with a 10 second timeout. Requests still pending after that
		are cancelled and re-issued once with a 2 second timeout; requests that finished with an
		exception are not retried.

		Args:
			target_id: Target to inspect.

		Returns:
			`TargetAllTrees` with the raw results and a `cdp_timing` breakdown in milliseconds. The
			snapshot's `documents` list is truncated to `self.max_iframes` entries.

		Raises:
			TimeoutError: If any of the four requests raised or did not finish in time (the same
				exception type is used for both cases).
		"""
		cdp_session = await self.browser_session.get_or_create_cdp_session(target_id=target_id, focus=False)

		# Best-effort probe before capturing. The returned readyState is not inspected, so this does not
		# actually wait for the page; a failure is only logged and the capture proceeds.
		try:
			await cdp_session.cdp_client.send.Runtime.evaluate(
				params={'expression': 'document.readyState'}, session_id=cdp_session.session_id
			)
		except Exception as e:
			# Page might not be ready yet
			self.logger.debug(f'document.readyState probe failed for target {target_id}: {e}')

		# DEBUG: Log before capturing snapshot
		self.logger.debug(f'🔍 DEBUG: Capturing DOM snapshot for target {target_id}')

		# Read the actual scroll positions of same-origin iframes before capturing the snapshot.
		# The values are only used for debug logging below.
		start_iframe_scroll = time.time()
		try:
			scroll_result = await cdp_session.cdp_client.send.Runtime.evaluate(
				params={
					'expression': """
					(() => {
						const scrollData = {};
						const iframes = document.querySelectorAll('iframe');
						iframes.forEach((iframe, index) => {
							try {
								const doc = iframe.contentDocument || iframe.contentWindow.document;
								if (doc) {
									scrollData[index] = {
										scrollTop: doc.documentElement.scrollTop || doc.body.scrollTop || 0,
										scrollLeft: doc.documentElement.scrollLeft || doc.body.scrollLeft || 0
									};
								}
							} catch (e) {
								// Cross-origin iframe, can't access
							}
						});
						return scrollData;
					})()
					""",
					'returnByValue': True,
				},
				session_id=cdp_session.session_id,
			)
			if scroll_result and 'result' in scroll_result and 'value' in scroll_result['result']:
				iframe_scroll_positions = scroll_result['result']['value']
				for idx, scroll_data in iframe_scroll_positions.items():
					self.logger.debug(
						f'🔍 DEBUG: Iframe {idx} actual scroll position - scrollTop={scroll_data.get("scrollTop", 0)}, scrollLeft={scroll_data.get("scrollLeft", 0)}'
					)
		except Exception as e:
			self.logger.debug(f'Failed to get iframe scroll positions: {e}')
		iframe_scroll_ms = (time.time() - start_iframe_scroll) * 1000

		# CDP request factories: each call creates a fresh coroutine, so a request can be re-issued on retry
		def create_snapshot_request():
			return cdp_session.cdp_client.send.DOMSnapshot.captureSnapshot(
				params={
					'computedStyles': REQUIRED_COMPUTED_STYLES,
					'includePaintOrder': True,
					'includeDOMRects': True,
					'includeBlendedBackgroundColors': False,
					'includeTextColorOpacities': False,
				},
				session_id=cdp_session.session_id,
			)

		def create_dom_tree_request():
			return cdp_session.cdp_client.send.DOM.getDocument(
				params={'depth': -1, 'pierce': True}, session_id=cdp_session.session_id
			)

		# result key -> (task name, coroutine factory)
		request_factories = {
			'snapshot': ('get_snapshot', create_snapshot_request),
			'dom_tree': ('get_dom_tree', create_dom_tree_request),
			'ax_tree': ('get_ax_tree', lambda: self._get_ax_tree_for_all_frames(target_id)),
			'device_pixel_ratio': ('get_viewport_ratio', lambda: self._get_viewport_ratio(target_id)),
		}

		start_cdp_calls = time.time()

		# Create initial tasks
		tasks = {
			key: create_task_with_error_handling(make_request(), name=task_name)
			for key, (task_name, make_request) in request_factories.items()
		}

		# Wait for all tasks with timeout
		_, pending = await asyncio.wait(tasks.values(), timeout=10.0)

		# Retry the tasks that timed out (tasks that finished with an exception are reported below, not retried)
		if pending:
			for task in pending:
				task.cancel()

			# Re-issue only the requests that did not complete
			for key, task in list(tasks.items()):
				if task in pending:
					task_name, make_request = request_factories[key]
					tasks[key] = create_task_with_error_handling(make_request(), name=f'{task_name}_retry')

			# Wait again with shorter timeout
			_, still_pending = await asyncio.wait([t for t in tasks.values() if not t.done()], timeout=2.0)

			for task in still_pending:
				task.cancel()

		# Extract results, tracking which ones failed
		results = {}
		failed = []
		for key, task in tasks.items():
			if task.done() and not task.cancelled():
				try:
					results[key] = task.result()
				except Exception as e:
					self.logger.warning(f'CDP request {key} failed with exception: {e}')
					failed.append(key)
			else:
				self.logger.warning(f'CDP request {key} timed out')
				failed.append(key)

		# All four results are required: raise if any is missing
		if failed:
			raise TimeoutError(f'CDP requests failed or timed out: {", ".join(failed)}')

		snapshot = results['snapshot']
		dom_tree = results['dom_tree']
		ax_tree = results['ax_tree']
		device_pixel_ratio = results['device_pixel_ratio']
		cdp_calls_ms = (time.time() - start_cdp_calls) * 1000

		start_snapshot_processing = time.time()

		# DEBUG: Log snapshot info and limit documents to prevent explosion
		if snapshot and 'documents' in snapshot:
			original_doc_count = len(snapshot['documents'])
			# Limit to max_iframes documents to prevent iframe explosion
			if original_doc_count > self.max_iframes:
				self.logger.warning(
					f'⚠️ Limiting processing of {original_doc_count} iframes on page to only first {self.max_iframes} to prevent crashes!'
				)
				snapshot['documents'] = snapshot['documents'][: self.max_iframes]

			total_nodes = sum(len(doc.get('nodes', [])) for doc in snapshot['documents'])
			self.logger.debug(f'🔍 DEBUG: Snapshot contains {len(snapshot["documents"])} frames with {total_nodes} total nodes')
			# Log iframe-specific info
			for doc_idx, doc in enumerate(snapshot['documents']):
				if doc_idx > 0:  # Not the main document
					self.logger.debug(
						f'🔍 DEBUG: Iframe #{doc_idx} {doc.get("frameId", "no-frame-id")} {doc.get("url", "no-url")} has {len(doc.get("nodes", []))} nodes'
					)

		snapshot_processing_ms = (time.time() - start_snapshot_processing) * 1000

		# Return with detailed timing breakdown
		return TargetAllTrees(
			snapshot=snapshot,
			dom_tree=dom_tree,
			ax_tree=ax_tree,
			device_pixel_ratio=device_pixel_ratio,
			cdp_timing={
				'iframe_scroll_detection_ms': iframe_scroll_ms,
				'cdp_parallel_calls_ms': cdp_calls_ms,
				'snapshot_processing_ms': snapshot_processing_ms,
			},
		)

	@observe_debug(ignore_input=True, ignore_output=True, name='get_dom_tree')
	async def get_dom_tree(
		self,
		target_id: TargetID,
		all_frames: dict | None = None,
		initial_html_frames: list[EnhancedDOMTreeNode] | None = None,
		initial_total_frame_offset: DOMRect | None = None,
		iframe_depth: int = 0,
	) -> tuple[EnhancedDOMTreeNode, dict[str, float]]:
		"""Get the DOM tree for a specific target.

		Fetches the DOM snapshot, DOM tree and accessibility tree of the target and merges them into a
		tree of `EnhancedDOMTreeNode` objects. When `self.cross_origin_iframes` is enabled, visible
		cross-origin iframes of at least 50x50px are expanded by calling this method recursively on the
		iframe's own target, up to `self.max_iframe_depth` levels.

		Args:
			target_id: Target ID of the page to get the DOM tree for.
			all_frames: Pre-fetched frame hierarchy to avoid redundant CDP calls (optional). When None it
				is fetched at most once per call, the first time a cross-origin iframe needs it.
			initial_html_frames: List of HTML frame nodes encountered so far
			initial_total_frame_offset: Accumulated coordinate offset
			iframe_depth: Current depth of iframe nesting to prevent infinite recursion

		Returns:
			Tuple of (enhanced_dom_tree_node, timing_info); timing values are in milliseconds.

		Raises:
			TimeoutError: Propagated from `_get_all_trees` when a CDP request failed or timed out.
		"""
		timing_info: dict[str, float] = {}
		timing_start_total = time.time()

		# Get all trees from CDP (snapshot, DOM, AX, viewport ratio)
		start_get_trees = time.time()
		trees = await self._get_all_trees(target_id)
		get_trees_ms = (time.time() - start_get_trees) * 1000
		timing_info.update(trees.cdp_timing)
		timing_info['get_all_trees_total_ms'] = get_trees_ms

		dom_tree = trees.dom_tree
		ax_tree = trees.ax_tree
		snapshot = trees.snapshot
		device_pixel_ratio = trees.device_pixel_ratio

		# Build AX tree lookup (backend node id -> AX node)
		start_ax = time.time()
		ax_tree_lookup: dict[int, AXNode] = {
			ax_node['backendDOMNodeId']: ax_node for ax_node in ax_tree['nodes'] if 'backendDOMNodeId' in ax_node
		}
		timing_info['build_ax_lookup_ms'] = (time.time() - start_ax) * 1000

		# NodeId (NOT backend node id) -> enhanced dom tree node; used to resolve parents and to memoize
		enhanced_dom_tree_node_lookup: dict[int, EnhancedDOMTreeNode] = {}

		# Parse snapshot data with everything calculated upfront
		start_snapshot = time.time()
		snapshot_lookup = build_snapshot_lookup(snapshot, device_pixel_ratio)
		timing_info['build_snapshot_lookup_ms'] = (time.time() - start_snapshot) * 1000

		async def _construct_enhanced_node(
			node: Node,
			html_frames: list[EnhancedDOMTreeNode] | None,
			total_frame_offset: DOMRect | None,
		) -> EnhancedDOMTreeNode:
			"""
			Recursively construct enhanced DOM tree nodes.

			Args:
				node: The DOM node to construct
				html_frames: List of HTML frame nodes encountered so far
				total_frame_offset: Accumulated coordinate translation from parent iframes (includes scroll corrections)
			"""
			# The frame hierarchy lives in the enclosing scope so that a lazy fetch triggered by one
			# cross-origin iframe is reused by every other iframe of this tree.
			nonlocal all_frames

			# Initialize lists if not provided
			if html_frames is None:
				html_frames = []

			# Always work on a private copy so sibling subtrees do not see each other's offset changes
			if total_frame_offset is None:
				total_frame_offset = DOMRect(x=0.0, y=0.0, width=0.0, height=0.0)
			else:
				total_frame_offset = DOMRect(
					total_frame_offset.x, total_frame_offset.y, total_frame_offset.width, total_frame_offset.height
				)

			# Memoize by nodeId in case a node is reachable more than once
			if node['nodeId'] in enhanced_dom_tree_node_lookup:
				return enhanced_dom_tree_node_lookup[node['nodeId']]

			ax_node = ax_tree_lookup.get(node['backendNodeId'])
			enhanced_ax_node = self._build_enhanced_ax_node(ax_node) if ax_node else None

			# CDP reports attributes as a flat [name, value, name, value, ...] list
			raw_attributes = node.get('attributes')
			attributes: dict[str, str] = dict(zip(raw_attributes[0::2], raw_attributes[1::2])) if raw_attributes else {}

			shadow_root_type = node.get('shadowRootType') or None

			# Get snapshot data and calculate absolute position
			snapshot_data = snapshot_lookup.get(node['backendNodeId'], None)
			absolute_position = None
			if snapshot_data and snapshot_data.bounds:
				absolute_position = DOMRect(
					x=snapshot_data.bounds.x + total_frame_offset.x,
					y=snapshot_data.bounds.y + total_frame_offset.y,
					width=snapshot_data.bounds.width,
					height=snapshot_data.bounds.height,
				)

			try:
				session = await self.browser_session.get_or_create_cdp_session(target_id, focus=False)
				session_id = session.session_id
			except ValueError:
				# Target may have detached during DOM construction
				session_id = None

			dom_tree_node = EnhancedDOMTreeNode(
				node_id=node['nodeId'],
				backend_node_id=node['backendNodeId'],
				node_type=NodeType(node['nodeType']),
				node_name=node['nodeName'],
				node_value=node['nodeValue'],
				attributes=attributes,
				is_scrollable=node.get('isScrollable', None),
				frame_id=node.get('frameId', None),
				session_id=session_id,
				target_id=target_id,
				content_document=None,
				shadow_root_type=shadow_root_type,
				shadow_roots=None,
				parent_node=None,
				children_nodes=None,
				ax_node=enhanced_ax_node,
				snapshot_node=snapshot_data,
				is_visible=None,
				absolute_position=absolute_position,
			)

			enhanced_dom_tree_node_lookup[node['nodeId']] = dom_tree_node

			if 'parentId' in node and node['parentId']:
				dom_tree_node.parent_node = enhanced_dom_tree_node_lookup[
					node['parentId']
				]  # parents should always be in the lookup

			# Check if this is an HTML frame node and add it to the list
			updated_html_frames = html_frames.copy()
			if node['nodeType'] == NodeType.ELEMENT_NODE.value and node['nodeName'] == 'HTML' and node.get('frameId') is not None:
				updated_html_frames.append(dom_tree_node)

				# and adjust the total frame offset by scroll
				if snapshot_data and snapshot_data.scrollRects:
					total_frame_offset.x -= snapshot_data.scrollRects.x
					total_frame_offset.y -= snapshot_data.scrollRects.y
					# DEBUG: Log iframe scroll information
					self.logger.debug(
						f'🔍 DEBUG: HTML frame scroll - scrollY={snapshot_data.scrollRects.y}, scrollX={snapshot_data.scrollRects.x}, frameId={node.get("frameId")}, nodeId={node["nodeId"]}'
					)

			# An IFRAME/FRAME element shifts the offset of everything inside its content document
			if (
				(node['nodeName'].upper() == 'IFRAME' or node['nodeName'].upper() == 'FRAME')
				and snapshot_data
				and snapshot_data.bounds
			):
				updated_html_frames.append(dom_tree_node)

				total_frame_offset.x += snapshot_data.bounds.x
				total_frame_offset.y += snapshot_data.bounds.y

			if 'contentDocument' in node and node['contentDocument']:
				dom_tree_node.content_document = await _construct_enhanced_node(
					node['contentDocument'], updated_html_frames, total_frame_offset
				)
				# forcefully set the parent node of the content document (helps traverse the tree)
				dom_tree_node.content_document.parent_node = dom_tree_node

			if 'shadowRoots' in node and node['shadowRoots']:
				dom_tree_node.shadow_roots = []
				for shadow_root in node['shadowRoots']:
					shadow_root_node = await _construct_enhanced_node(shadow_root, updated_html_frames, total_frame_offset)
					# forcefully set the parent node of the shadow root (helps traverse the tree)
					shadow_root_node.parent_node = dom_tree_node
					dom_tree_node.shadow_roots.append(shadow_root_node)

			if 'children' in node and node['children']:
				dom_tree_node.children_nodes = []
				# Build set of shadow root node IDs to filter them out from children
				shadow_root_node_ids = set()
				if 'shadowRoots' in node and node['shadowRoots']:
					for shadow_root in node['shadowRoots']:
						shadow_root_node_ids.add(shadow_root['nodeId'])

				for child in node['children']:
					# Skip shadow roots - they should only be in shadow_roots list
					if child['nodeId'] in shadow_root_node_ids:
						continue
					dom_tree_node.children_nodes.append(
						await _construct_enhanced_node(child, updated_html_frames, total_frame_offset)
					)

			# Set visibility using the collected HTML frames (done after the subtree has been built)
			dom_tree_node.is_visible = self.is_element_visible_according_to_all_parents(dom_tree_node, updated_html_frames)

			# DEBUG: Log visibility info for form elements in iframes
			# NOTE(review): only city/state/zip fields are logged - looks like leftover targeted debugging; confirm
			if dom_tree_node.tag_name and dom_tree_node.tag_name.upper() in ['INPUT', 'SELECT', 'TEXTAREA', 'LABEL']:
				attrs = dom_tree_node.attributes or {}
				elem_id = attrs.get('id', '')
				elem_name = attrs.get('name', '')
				if (
					'city' in elem_id.lower()
					or 'city' in elem_name.lower()
					or 'state' in elem_id.lower()
					or 'state' in elem_name.lower()
					or 'zip' in elem_id.lower()
					or 'zip' in elem_name.lower()
				):
					self.logger.debug(
						f"🔍 DEBUG: Form element {dom_tree_node.tag_name} id='{elem_id}' name='{elem_name}' - visible={dom_tree_node.is_visible}, bounds={dom_tree_node.snapshot_node.bounds if dom_tree_node.snapshot_node else 'NO_SNAPSHOT'}"
					)

			# handle cross origin iframe (just recursively call the main function with the proper target if it exists in iframes)
			# only do this if the iframe is visible (otherwise it's not worth it)

			if (
				# TODO: hacky way to disable cross origin iframes for now
				self.cross_origin_iframes and node['nodeName'].upper() == 'IFRAME' and node.get('contentDocument', None) is None
			):  # None meaning there is no content
				# Check iframe depth to prevent infinite recursion
				if iframe_depth >= self.max_iframe_depth:
					self.logger.debug(
						f'Skipping iframe at depth {iframe_depth} to prevent infinite recursion (max depth: {self.max_iframe_depth})'
					)
				else:
					# Check if iframe is visible and large enough (>= 50px in both dimensions)
					should_process_iframe = False

					# First check if the iframe element itself is visible
					if dom_tree_node.is_visible:
						# Check iframe dimensions
						if dom_tree_node.snapshot_node and dom_tree_node.snapshot_node.bounds:
							bounds = dom_tree_node.snapshot_node.bounds
							width = bounds.width
							height = bounds.height

							# Only process if iframe is at least 50px in both dimensions
							if width >= 50 and height >= 50:
								should_process_iframe = True
								self.logger.debug(f'Processing cross-origin iframe: visible=True, width={width}, height={height}')
							else:
								self.logger.debug(
									f'Skipping small cross-origin iframe: width={width}, height={height} (needs >= 50px)'
								)
						else:
							self.logger.debug('Skipping cross-origin iframe: no bounds available')
					else:
						self.logger.debug('Skipping invisible cross-origin iframe')

					if should_process_iframe:
						# Lazy fetch all_frames only when actually needed (for cross-origin iframes);
						# the result is kept in the enclosing scope and reused for later iframes
						if all_frames is None:
							all_frames, _ = await self.browser_session.get_all_frames()

						# Use all_frames to find the iframe's target (no redundant CDP call)
						iframe_document_target = None
						frame_id = node.get('frameId', None)
						if frame_id:
							frame_info = all_frames.get(frame_id)
							if frame_info and frame_info.get('frameTargetId'):
								iframe_target_id = frame_info['frameTargetId']
								iframe_target = self.browser_session.session_manager.get_target(iframe_target_id)
								if iframe_target:
									iframe_document_target = {
										'targetId': iframe_target.target_id,
										'url': iframe_target.url,
										'title': iframe_target.title,
										'type': iframe_target.target_type,
									}

						# if target actually exists in one of the frames, just recursively build the dom tree for it
						if iframe_document_target:
							self.logger.debug(
								f'Getting content document for iframe {node.get("frameId", None)} at depth {iframe_depth + 1}'
							)
							content_document, _ = await self.get_dom_tree(
								target_id=iframe_document_target['targetId'],
								all_frames=all_frames,
								# TODO: experiment with this values -> not sure whether the whole cross origin iframe should be ALWAYS included as soon as some part of it is visible or not.
								# Current config: if the cross origin iframe is AT ALL visible, then just include everything inside of it!
								# initial_html_frames=updated_html_frames,
								initial_total_frame_offset=total_frame_offset,
								iframe_depth=iframe_depth + 1,
							)

							dom_tree_node.content_document = content_document
							dom_tree_node.content_document.parent_node = dom_tree_node

			return dom_tree_node

		# Build enhanced DOM tree recursively
		# Note: if all_frames was passed as None it stays None until a cross-origin iframe is
		# encountered, at which point it is fetched once inside _construct_enhanced_node
		start_construct = time.time()
		enhanced_dom_tree_node = await _construct_enhanced_node(dom_tree['root'], initial_html_frames, initial_total_frame_offset)
		timing_info['construct_enhanced_tree_ms'] = (time.time() - start_construct) * 1000

		# Calculate total time for get_dom_tree
		total_get_dom_tree_ms = (time.time() - timing_start_total) * 1000
		timing_info['get_dom_tree_total_ms'] = total_get_dom_tree_ms

		# Calculate overhead in get_dom_tree (time not accounted for by sub-operations)
		tracked_sub_operations_ms = (
			timing_info.get('get_all_trees_total_ms', 0)
			+ timing_info.get('build_ax_lookup_ms', 0)
			+ timing_info.get('build_snapshot_lookup_ms', 0)
			+ timing_info.get('construct_enhanced_tree_ms', 0)
		)
		get_dom_tree_overhead_ms = total_get_dom_tree_ms - tracked_sub_operations_ms
		if get_dom_tree_overhead_ms > 0.1:
			timing_info['get_dom_tree_overhead_ms'] = get_dom_tree_overhead_ms

		return enhanced_dom_tree_node, timing_info

	@observe_debug(ignore_input=True, ignore_output=True, name='get_serialized_dom_tree')
	async def get_serialized_dom_tree(
		self, previous_cached_state: SerializedDOMState | None = None
	) -> tuple[SerializedDOMState, EnhancedDOMTreeNode, dict[str, float]]:
		"""Build the DOM tree of the agent-focused target and serialize it for LLM consumption.

		Args:
			previous_cached_state: Earlier serialized state, handed to `DOMTreeSerializer` unchanged.

		Returns:
			Tuple of (serialized_dom_state, enhanced_dom_tree_root, timing_info); timing values are in
			milliseconds.
		"""
		started_at = time.time()

		# The session must currently have a focused target
		focused_target_id = self.browser_session.agent_focus_target_id
		assert focused_target_id is not None

		session_id = self.browser_session.id

		# Build DOM tree (includes CDP calls for snapshot, DOM, AX tree).
		# all_frames is left as None so get_dom_tree fetches it lazily, only if a cross-origin iframe needs it.
		enhanced_dom_tree, dom_tree_timing = await self.get_dom_tree(
			target_id=focused_target_id,
			all_frames=None,
		)

		# Start the report from the DOM-tree sub-timings
		timing_info: dict[str, float] = dict(dom_tree_timing)

		# Serialize DOM tree for LLM
		serialize_started_at = time.time()
		serializer = DOMTreeSerializer(
			enhanced_dom_tree, previous_cached_state, paint_order_filtering=self.paint_order_filtering, session_id=session_id
		)
		serialized_dom_state, serializer_timing = serializer.serialize_accessible_elements()
		total_serialization_ms = (time.time() - serialize_started_at) * 1000

		# Serializer sub-timings are multiplied by 1000 and stored under '<key>_ms'
		timing_info.update({f'{key}_ms': value * 1000 for key, value in serializer_timing.items()})

		# Serialization time not covered by the serializer's own sub-timings (recorded only if significant)
		tracked_serialization_ms = sum(value * 1000 for value in serializer_timing.values())
		serialization_overhead_ms = total_serialization_ms - tracked_serialization_ms
		if serialization_overhead_ms > 0.1:
			timing_info['serialization_overhead_ms'] = serialization_overhead_ms

		# Total time for get_serialized_dom_tree
		total_elapsed_ms = (time.time() - started_at) * 1000
		timing_info['get_serialized_dom_tree_total_ms'] = total_elapsed_ms

		# Time spent in this method outside of tree building and serialization
		untracked_ms = total_elapsed_ms - (timing_info.get('get_dom_tree_total_ms', 0) + total_serialization_ms)
		if untracked_ms > 0.1:
			timing_info['get_serialized_dom_tree_overhead_ms'] = untracked_ms

		return serialized_dom_state, enhanced_dom_tree, timing_info

	@staticmethod
	def detect_pagination_buttons(selector_map: dict[int, EnhancedDOMTreeNode]) -> list[dict[str, str | int | bool]]:
		"""Detect pagination buttons from the selector map.

		Args:
			selector_map: Map of element indices to EnhancedDOMTreeNode

		Returns:
			List of pagination button information dicts with:
			- button_type: 'next', 'prev', 'first', 'last', 'page_number'
			- backend_node_id: Backend node ID for clicking
			- text: Button text/label
			- selector: XPath selector
			- is_disabled: Whether the button appears disabled
		"""
		pagination_buttons: list[dict[str, str | int | bool]] = []

		# Common pagination patterns to look for
		next_patterns = ['next', '>', '»', '→', 'siguiente', 'suivant', 'weiter', 'volgende']
		prev_patterns = ['prev', 'previous', '<', '«', '←', 'anterior', 'précédent', 'zurück', 'vorige']
		first_patterns = ['first', '⇤', '«', 'primera', 'première', 'erste', 'eerste']
		last_patterns = ['last', '⇥', '»', 'última', 'dernier', 'letzte', 'laatste']

		for index, node in selector_map.items():
			# Skip non-clickable elements
			if not node.snapshot_node or not node.snapshot_node.is_clickable:
				continue

			# Get element text and attributes
			text = node.get_all_children_text().lower().strip()
			aria_label = node.attributes.get('aria-label', '').lower()
			title = node.attributes.get('title', '').lower()
			class_name = node.attributes.get('class', '').lower()
			role = node.attributes.get('role', '').lower()

			# Combine all text sources for pattern matching
			all_text = f'{text} {aria_label} {title} {class_name}'.strip()

			# Check if it's disabled
			is_disabled = (
				node.attributes.get('disabled') == 'true'
				or node.attributes.get('aria-disabled') == 'true'
				or 'disabled' in class_name
			)

			button_type: str | None = None

			# Check for next button
			if any(pattern in all_text for pattern in next_patterns):
				button_type = 'next'
			# Check for previous button
			elif any(pattern in all_text for pattern in prev_patterns):
				button_type = 'prev'
			# Check for first button
			elif any(pattern in all_text for pattern in first_patterns):
				button_type = 'first'
			# Check for last button
			elif any(pattern in all_text for pattern in last_patterns):
				button_type = 'last'
			# Check for numeric page buttons (single or double digit)
			elif text.isdigit() and len(text) <= 2 and role in ['button', 'link', '']:
				button_type = 'page_number'

			if button_type:
				pagination_buttons.append(
					{
						'button_type': button_type,
						'backend_node_id': index,
						'text': node.get_all_children_text().strip() or aria_label or title,
						'selector': node.xpath,
						'is_disabled': is_disabled,
					}
				)

		return pagination_buttons
